{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 添加ONNXToMindSpore算子映射关系高级教程\n",
    "\n",
    "`Linux` `Ascend` `GPU` `CPU` `模型迁移` `高级`\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/mindinsight/blob/master/ecosystem_tools/mindconverter/tutorial/add_onnx2mindspore_operator_mapper_advanced_tutorial.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概述\n",
    "\n",
    "在确定ONNX算子到MindSpore算子的映射关系时，会遇到两者之间不存在相似实现或者参数差异过大难以直接转换的算子的问题。本文将在[初级教程](https://gitee.com/mindspore/mindinsight/blob/master/ecosystem_tools/mindconverter/tutorial/add_onnx2mindspore_operator_mapper_base_tutorial.ipynb)的基础上，以该类算子映射关系为例，来描述添加算子映射关系文件的方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 环境准备\n",
    "\n",
    "本案例需安装以下Python三方库：\n",
    "\n",
    "```bash\n",
    "pip install mindspore==1.6.0\n",
    "pip install mindconverter==1.6.0\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义添加算子映射脚本\n",
    "\n",
    "以`onnx::AveragePool`算子为例进行演示。\n",
    "\n",
    "分别查阅[ONNX算子API文档](https://github.com/onnx/onnx/blob/master/docs/Operators.md)和[MindSpore算子API文档](https://www.mindspore.cn/docs/zh-CN/master/index.html)，\n",
    "找到与ONNX算子`onnx::AveragePool`功能相同或相近的MindSpore算子`mindspore.nn.AvgPool2d`。\n",
    "\n",
    "|算子名|`onnx::AveragePool`|`mindspore.nn.AvgPool2d`|\n",
    "|:----:|:----|:----|\n",
    "|算法实现|`output_shape[i] = floor((input_shape[i]+pad_shape[i]-kernel_shape[i])/strides_shape[i])`<br>OR<br>`output_shape[i] = ceil((input_shape[i]+pad_shape[i]-kernel_shape[i])/strides_shape[i])` based on `ceil_mode`|`output_shape[i] = ceil((input_shape[i]-kernel_size[i]+1)/stride_shape[i])`<br>OR<br>`output_shape[i] = ceil(input_shape[i]/stride_shape[i])` based on `pad_mode`|\n",
    "|参数|`auto_pad`: DEPRECATED<br>`ceil_mode`: optional<br>`count_include_pad`: optional<br>`kernel_shape`: optional<br>`pads`: optional<br>`strides`: optional|`kernel_size`: optional<br>`stride`: optional<br>`pad_mode`: optional<br>`data_format`: optional<br>|\n",
    "|输入|`X`: required|`input`: required|\n",
    "|输出|`Y`|`output`|\n",
    "\n",
    "<br>\n",
    "依据双方算子中参数（Attributes/Parameters）和输入（Inputs）进行ONNX到MindSpore的算子映射。"
   ]
  },
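  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gap between the two output-size rules can be checked numerically. The following is a minimal sketch and not part of MindConverter: the helper names `onnx_pool_out` and `ms_pool_out` are hypothetical, dilation is ignored, and a total padding of 3 per spatial axis is assumed purely for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def onnx_pool_out(size, kernel, stride, pad_total=0, ceil_mode=False):\n",
    "    # ONNX rule without dilation: round((size + pad_total - kernel) / stride + 1)\n",
    "    rounding = math.ceil if ceil_mode else math.floor\n",
    "    return rounding((size + pad_total - kernel) / stride + 1)\n",
    "\n",
    "\n",
    "def ms_pool_out(size, kernel, stride, pad_mode='valid'):\n",
    "    # MindSpore rule: 'valid' keeps only complete windows, 'same' depends on the stride alone\n",
    "    if pad_mode == 'valid':\n",
    "        return math.ceil((size - kernel + 1) / stride)\n",
    "    return math.ceil(size / stride)\n",
    "\n",
    "\n",
    "onnx_out = onnx_pool_out(224, kernel=5, stride=2, pad_total=3)\n",
    "ms_out = ms_pool_out(224, kernel=5, stride=2)\n",
    "ms_out_padded = ms_pool_out(224 + 3, kernel=5, stride=2)\n",
    "print(onnx_out, ms_out, ms_out_padded)  # prints: 112 110 112\n",
    "```\n",
    "\n",
    "Under these assumptions, unpadded MindSpore `valid` pooling yields 110 where ONNX yields 112, and padding the input by 3 beforehand closes the gap. This is the motivation for placing a `Pad` operator in front of the pooling operator."
   ]
  },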
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper\n",
    "from mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords\n",
    "\n",
    "\n",
    "class PoolMapper(ONNXToMindSporeMapper):\n",
    "    \"\"\"Pool mapper.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _operation_name_in_ms(*args, **kwargs):\n",
    "        if kwargs['op_name'] == 'onnx::AveragePool':\n",
    "            op_name = 'nn.AvgPool{}d'\n",
    "        else:\n",
    "            op_name = 'nn.MaxPool{}d'\n",
    "        dim = len(kwargs['params']['strides'])\n",
    "        return op_name.format(dim)\n",
    "\n",
    "    @staticmethod\n",
    "    def _convert_params(**kwargs):\n",
    "        params = kwargs['params']\n",
    "        transformed_params = dict()\n",
    "        transformed_params[\"kernel_size\"] = tuple(params['kernel_shape'])\n",
    "        transformed_params[\"stride\"] = tuple(params['strides'])\n",
    "\n",
    "        return transformed_params\n",
    "\n",
    "    @staticmethod\n",
    "    def _convert_trained_weights(**kwargs):\n",
    "        return dict()\n",
    "\n",
    "    @staticmethod\n",
    "    def _get_ms_opt_shape(**kwargs):\n",
    "        \"\"\"用于计算MindSpore算子在使用ONNX参数时，由`input_shape`得到的`output_shape`。\"\"\"\n",
    "        params = kwargs['raw_params']\n",
    "        input_shape = params['input_shape']\n",
    "        kernel_shape = params['kernel_shape']\n",
    "        strides = params['strides']\n",
    "        dilations = params.get('dilations', (1, 1))\n",
    "        ms_opt_shape = np.true_divide(np.subtract(np.array(input_shape[-len(kernel_shape):], dtype=np.float32),\n",
    "                                                  ((np.array(kernel_shape, dtype=np.float32) - 1) *\n",
    "                                                   np.array(dilations, dtype=np.float32) + 1)) + 1,\n",
    "                                      np.array(strides, dtype=np.float32)).tolist()\n",
    "        ms_opt_shape_ceil = tuple(math.ceil(ms_opt_shape_axis) for ms_opt_shape_axis in ms_opt_shape)\n",
    "        return ms_opt_shape_ceil\n",
    "\n",
    "    @staticmethod\n",
    "    def _generate_snippet_template(**kwargs):\n",
    "        \"\"\"\n",
    "        对于无法直接使用`_convert_params`方法进行参数映射的算子，重写此方法通过自定义的模板\n",
    "        来生成算子在转换脚本中的定义（`init`）和调用（`construct`）。\n",
    "\n",
    "        Args:\n",
    "            operation (str): MindSpore中的对应算子名。\n",
    "            converted_params (dict): 由`_convert_params`方法转换得到的MindSpore算子的参数。\n",
    "            raw_params (dict): ONNX算子的参数(`raw_params`)，`input_shape`和`output_shape`。\n",
    "        \"\"\"\n",
    "\n",
    "        op = kwargs.get(\"operation\")\n",
    "        args = kwargs.get(\"converted_params\", dict())\n",
    "\n",
    "        ms_opt_shape = PoolMapper._get_ms_opt_shape(**kwargs)\n",
    "        tensor_opt_shape = kwargs['raw_params']['output_shape']\n",
    "        tensor_ipt_shape = kwargs['raw_params']['input_shape']\n",
    "        kernel_shape = kwargs['raw_params']['kernel_shape']\n",
    "        dilations = kwargs['raw_params'].get('dilations', (1, 1))\n",
    "        strides = kwargs['raw_params']['strides']\n",
    "\n",
    "        if not op:\n",
    "            raise ValueError(\"Can not get MindSpore operation name.\")\n",
    "\n",
    "        # 定义生成代码的模板。`init_xx`是在`init`中的算子定义，`construct_xx`是在`construct`中的算子调用，\n",
    "        # 其中的`variable_slot`是替换用标签，会被后续的脚本生成模块填充。\n",
    "        variable_slot = \"var_0\"\n",
    "        init_template = f\"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})\"\n",
    "        construct_template = f\"opt_{{{variable_slot}}} = self.{{{variable_slot}}}(opt_{{{variable_slot}}})\"\n",
    "\n",
    "        # 由于该算子在ONNX和MindSpore中的实现差异较大，为了保证转换结果的一致性，需要添加`mindspore.nn.Pad`算子，\n",
    "        # 对输入进行处理之后，再传入算子中进行推理。\n",
    "        # 该方法的输出依次为`Pad`算子定义，`Pad`算子调用和`Pad`算子的参数`paddings`。\n",
    "        init_template_pad, construct_template_pad, paddings = \\\n",
    "            PoolMapper._generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape,\n",
    "                                                        ms_opt_shape, variable_slot,\n",
    "                                                        kernel_shape, dilations, strides)\n",
    "\n",
    "        # 返回给后续模块的生成模板数据体，将按照列表顺序依次生成算子定义和算子调用，\n",
    "        # `TemplateKeyWords.INIT.value`和`TemplateKeyWords.CONSTRUCT.value`分别表示`init`和`construct`。\n",
    "        template = {\n",
    "            variable_slot: {\n",
    "                TemplateKeywords.INIT.value: [init_template_pad, init_template],\n",
    "                TemplateKeywords.CONSTRUCT.value: [construct_template_pad, construct_template]\n",
    "            }\n",
    "        }\n",
    "\n",
    "        # 新添加算子`Pad`的参数`paddings`也作为算子`Pool`的参数进行返回，使该参数也能正确的进行设置。\n",
    "        args['paddings'] = paddings\n",
    "\n",
    "        # 用于与后续模块进行信息交换。\n",
    "        exchange_msg = {\n",
    "            variable_slot: {\n",
    "                ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,  # MindSpore算子名。\n",
    "                ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,  # 算子对应的变量名，由后续模块填写，此处为None。\n",
    "                ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:\n",
    "                    ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,  # 算子输出的类型，`mindspore.Tensor`或者`Tuple<mindspore.Tensor>`。\n",
    "                ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],  # 算子输入，由后续模块填写，此处为list()。\n",
    "                ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,  # 算子参数。\n",
    "                ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict(),  # 算子的权重信息。\n",
    "                ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict()  # 算子的可训练权重信息。由`_convert_trained_weights`方法返回。\n",
    "            }\n",
    "        }\n",
    "        # 算子输出的变量名。若为多输出，则按照列表顺序依次生成。\n",
    "        outputs_list = [f\"opt_{{{variable_slot}}}\"]\n",
    "        # ONNX算子和MindSpore算子输出的对应顺序，主要用于保证多输出算子输出拓扑序的一致性。\n",
    "        outputs_mapping = ((0, 0),)\n",
    "        return template, exchange_msg, outputs_list, outputs_mapping\n",
    "\n",
    "    @staticmethod\n",
    "    def _generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape,\n",
    "                                         ms_opt_shape, variable_slot, kernel_shape, dilations, strides):\n",
    "        \"\"\"\n",
    "        生成`Pad`算子定义语句，`Pad`算子调用语句和计算参数`paddings`。\n",
    "\n",
    "        Args:\n",
    "            tensor_opt_shape (tuple): ONNX算子输出尺寸。\n",
    "            tensor_ipt_shape (tuple): ONNX算子输入尺寸。\n",
    "       ms_opt_shape (tuple): MindSpore算子输出尺寸。\n",
    "            variable_slot (str): 用于后续模块进行替换的标识符。\n",
    "            kernel_shape (Union[tuple, int]): ONNX算子参数`kernel_shape`。\n",
    "            dilations (Union[tuple, int]): ONNX算子参数`dilations`。\n",
    "            strides (Union[tuple, int]): ONNX算子参数`strides`。\n",
    "        \"\"\"\n",
    "\n",
    "        onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):]\n",
    "        onnx_ipt_shape = tensor_ipt_shape[-len(ms_opt_shape):]\n",
    "\n",
    "        if np.any(np.array(ms_opt_shape) > np.array(onnx_opt_shape)):\n",
    "            raise ValueError(f\"ms_opt_shape[{ms_opt_shape}] should be no larger than onnx_opt_shape[{onnx_opt_shape}].\")\n",
    "\n",
    "        if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)):\n",
    "            shape_diff = np.zeros(len(ms_opt_shape)).astype(np.int).tolist()\n",
    "        else:\n",
    "            shape_diff = np.subtract((np.array(onnx_opt_shape) - 1) * np.array(strides),\n",
    "                                     np.subtract(np.array(onnx_ipt_shape),\n",
    "                                                 (np.array(kernel_shape) - 1) * np.array(dilations) + 1)).tolist()\n",
    "\n",
    "        zero_pad_single = (0, 0)\n",
    "        paddings = [zero_pad_single]\n",
    "        num_zero_pads = len(tensor_opt_shape) - len(ms_opt_shape)\n",
    "        for _ in range(num_zero_pads - 1):\n",
    "            paddings.append(zero_pad_single)\n",
    "\n",
    "        for axis_diff in shape_diff:\n",
    "            paddings.append((int(axis_diff // 2), int(axis_diff // 2 + axis_diff % 2)))\n",
    "\n",
    "        init_template_pad = f\"self.pad_{{{variable_slot}}} = nn.Pad(paddings={{paddings}})\"\n",
    "        construct_template_pad = f\"opt_{{{variable_slot}}} = self.pad_{{{variable_slot}}}\" \\\n",
    "                                 f\"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})\"\n",
    "\n",
    "        return init_template_pad, construct_template_pad, tuple(paddings)"
   ]
  },
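  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The padding arithmetic in `_generate_pad_init_and_construct` can be traced in isolation. The sketch below is a simplified standalone illustration: `split_paddings` is a hypothetical helper, it assumes two leading non-spatial axes (batch and channel), and it omits the branch that returns zero padding when the two output shapes already match.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def split_paddings(onnx_out, onnx_in, kernel, strides, dilations=(1, 1), leading_dims=2):\n",
    "    # Extent of the kernel once dilation is taken into account\n",
    "    effective_kernel = (np.array(kernel) - 1) * np.array(dilations) + 1\n",
    "    # Extra input length needed per axis so that pooling reaches the ONNX output size\n",
    "    shape_diff = (np.array(onnx_out) - 1) * np.array(strides) - (np.array(onnx_in) - effective_kernel)\n",
    "    paddings = [(0, 0)] * leading_dims  # no padding on the batch and channel axes\n",
    "    for diff in shape_diff.tolist():\n",
    "        # Split the difference between both sides; the odd remainder goes to the end\n",
    "        paddings.append((diff // 2, diff // 2 + diff % 2))\n",
    "    return tuple(paddings)\n",
    "\n",
    "\n",
    "print(split_paddings((112, 112), (224, 224), kernel=(5, 5), strides=(2, 2)))\n",
    "# prints: ((0, 0), (0, 0), (1, 2), (1, 2))\n",
    "```\n",
    "\n",
    "Each spatial axis needs `(112 - 1) * 2 - (224 - 5) = 3` extra elements, which are split into 1 before and 2 after."
   ]
  },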
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将该Mapper脚本命名为`pool_mapper.py`，该命名方式需要和类名（`PoolMapper`）相对应。<br>\n",
    "并放入 `mindconverter/graph_based_converter/mapper/impl/nn`目录下，该放置目录需要根据对应的MindSpore算子所在的层（`nn`/`ops`）来设置。<br>\n",
    "最后修改 `mindconverter/graph_based_converter/mapper/onnx_to_ms.json`，\n",
    "添加 `\"onnx::AveragePool\": \"mindconverter.graph_based_converter.mapper.impl.nn.pool_mapper.PoolMapper\"`来确定ONNX算子所对应的Mapper脚本文件。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 验证自定义算子映射脚本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper\n",
    "from mindconverter.graph_based_converter.common.code_fragment import Fragment\n",
    "\n",
    "\n",
    "def test_mapper(onnx_info):\n",
    "    \"\"\"\n",
    "    Test mapper.\n",
    "\n",
    "    Args:\n",
    "        onnx_info (dict): Onnx operator_info. Struct is\n",
    "                                   {\n",
    "                                    'op_name': op_name,\n",
    "                                    'attributes': dict(),\n",
    "                                    'weights': [NodeWeight(), ...]\n",
    "                                   }\n",
    "    \"\"\"\n",
    "\n",
    "    template, exchange_msg, outputs_lists, outputs_mapping = \\\n",
    "        ONNXToMindSporeMapper.convert(onnx_info['op_name'],\n",
    "                                      onnx_info['attributes'],\n",
    "                                      onnx_info['weights'])\n",
    "\n",
    "    exchange_msg['var_0']['variable_name'] = 'self_defined_operator'\n",
    "    exchange_msg['var_0']['inputs'] = ['x']\n",
    "\n",
    "    fragment = Fragment(data_entity=exchange_msg, code_template=template, outputs=outputs_lists,\n",
    "                        outputs_mapping=outputs_mapping)\n",
    "\n",
    "    code = fragment()\n",
    "    init_code = code[0]\n",
    "    construct_code = code[1]\n",
    "    print('-'*30, 'init_code', '-'*30)\n",
    "    print('\\n'.join(init_code))\n",
    "    print('-'*30, 'construct_code', '-'*30)\n",
    "    print('\\n'.join(construct_code))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------ init_code ------------------------------\n",
      "self.pad_self_defined_operator = nn.Pad(paddings=((0, 0), (0, 0), (1, 2), (1, 2)))\n",
      "self.self_defined_operator = nn.AvgPool2d(kernel_size=(5, 5), stride=(2, 2))\n",
      "------------------------------ construct_code ------------------------------\n",
      "opt_self_defined_operator = self.pad_self_defined_operator(x)\n",
      "opt_self_defined_operator = self.self_defined_operator(opt_self_defined_operator)\n"
     ]
    }
   ],
   "source": [
    "onnx_operator_info = {'op_name': 'onnx::AveragePool',\n",
    "                      'attributes': {'auto_pad': 'NOTSET',\n",
    "                                     'ceil_mode': 0,\n",
    "                                     'count_include_pad': 0,\n",
    "                                     'kernel_shape': (5, 5),\n",
    "                                     'pads': (0, 0, 0, 0),\n",
    "                                     'strides': (2, 2),\n",
    "                                     'input_shape': (1, 3, 224, 224),\n",
    "                                     'output_shape': (1, 3, 112, 112)\n",
    "                                    },\n",
    "                      'weights': []}\n",
    "test_mapper(onnx_operator_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 权重迁移相关教程\n",
    "\n",
    "以`onnx::Add`算子为例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords, \\\n",
    "    WeightType\n",
    "from mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper\n",
    "from mindconverter.graph_based_converter.third_party_graph.base import NodeWeight\n",
    "\n",
    "\n",
    "class AddMapper(ONNXToMindSporeMapper):\n",
    "    \"\"\"Add mapper.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _operation_name_in_ms(*args, **kwargs):\n",
    "        return \"P.Add\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _convert_params(**kwargs):\n",
    "        return dict()\n",
    "\n",
    "    @staticmethod\n",
    "    def _convert_trained_weights(**kwargs):\n",
    "        \"\"\"\n",
    "        权重迁移相关方法，返回数据体用于生成CheckPoint文件。\n",
    "\n",
    "        Returns, dict(MindSpore算子权重名: {'data': 权重值, 'type': 权重类型， 'onnx_name': ONNX算子权重名})\n",
    "        \"\"\"\n",
    "        weights = kwargs.get('weights', list())  # 获取算子输入当中的静态ensor数据体,即为该算子权重，保存在CheckPoint文件当中。\n",
    "        tensor = AddMapper._find_val_by_index(0, weights)  # 获取权重值，类型为`numpy.ndarray`。\n",
    "        onnx_name = AddMapper._find_onnx_name_by_index(0, weights)  # 获取权重在ONNX框架中的名称，主要用于权重共享相关功能。\n",
    "        # 仅当静态tensor为`np.ndarray`且存在`shape`信息时，该tensor会被保存为权重。\n",
    "        if isinstance(tensor, np.ndarray) and tensor.shape:\n",
    "            return {'bias': {'data': tensor, 'type': WeightType.PARAMETER.value, 'onnx_name': onnx_name}}\n",
    "        return dict()\n",
    "\n",
    "    @staticmethod\n",
    "    def _generate_snippet_template(**kwargs):\n",
    "        template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(\n",
    "            **kwargs)\n",
    "        op = kwargs.get(\"operation\")\n",
    "        args = kwargs.get(\"converted_params\")\n",
    "        weights = kwargs.get(\"weights\")\n",
    "        trainable_params = kwargs.get('trainable_params', dict())  # 获取`_convert_trained_weights`方法的返回值。\n",
    "        if not op:\n",
    "            raise ValueError(\"Can not get MindSpore operation name.\")\n",
    "        if not weights:\n",
    "            return template, exchange_msg, outputs_list, outputs_mapping\n",
    "\n",
    "        tensor = AddMapper._find_val_by_index(0, weights)\n",
    "        bias_shape = tensor.shape\n",
    "        # 该静态Tensor在原ONNX算子中的输入中的位置序列号，例如：在算子`onnx::Add(x, y)`中，`x`的位置序列号为0，`y`的位置序列号为1。\n",
    "        bias_location = AddMapper._find_location_by_index(0, weights)\n",
    "\n",
    "        variable_slot = \"var_0\"\n",
    "        init_template = f\"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})\"\n",
    "        inputs_in_construct = [f\"{{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}\"]\n",
    "\n",
    "        # 使用该位置序列号信息，确保该静态Tensor在生成的MindSpore算子中的输入顺序和原ONNX算子中的输入顺序保持一致。\n",
    "        if bias_location != -1:\n",
    "            inputs_in_construct.insert(bias_location, f\"self.{{{variable_slot}}}_bias\")\n",
    "\n",
    "        # 构建出常量Tensor算子，作为算子的输入。\n",
    "        # `XXX/bias`和`XXX_bias`当中的`bias`需要\n",
    "        # 和`_convert_trained_weights`方法返回值当中定义的`bias`（MindSpore算子权重名）保持一致。\n",
    "        if bias_shape:\n",
    "            # Note: adding weight shape to args is now deprecated due to conflict of partial weights share processing.\n",
    "            variable_slot_param_name = f\"{variable_slot}/bias\"  # XX/bias`\n",
    "            init_tensor = f\"self.{{{variable_slot}}}_bias = {{{variable_slot_param_name}}}\"\n",
    "\n",
    "        else:\n",
    "            # 当`shape`信息为None时，`tensor.tolist()`返回单个数值，这种情况下，该值作为算子参数，构建出常量算子作为算子输入。\n",
    "            args[\"bias_value\"] = tensor.tolist()\n",
    "            init_tensor = f\"self.{{{variable_slot}}}_bias = {{bias_value}}\"\n",
    "\n",
    "        construct_template = f\"opt_{{{variable_slot}}} = self.{{{variable_slot}}}\" \\\n",
    "                             f\"({', '.join(inputs_in_construct)})\"\n",
    "        template = {\n",
    "            variable_slot: {\n",
    "                TemplateKeywords.INIT.value: [init_template, init_tensor],\n",
    "                TemplateKeywords.CONSTRUCT.value: [construct_template]\n",
    "            }\n",
    "        }\n",
    "        exchange_msg = {\n",
    "            variable_slot: {\n",
    "                ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,\n",
    "                ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,\n",
    "                ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:\n",
    "                    ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,\n",
    "                ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],\n",
    "                ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,\n",
    "                ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,\n",
    "                ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params\n",
    "            }\n",
    "        }\n",
    "\n",
    "        # 权重共享相关。声明权重名称，权重值由后续模块添加。\n",
    "        if bias_shape:\n",
    "            exchange_msg[variable_slot][ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value] = {\n",
    "                \"bias\": \"\"\n",
    "            }\n",
    "        outputs_list = [f\"opt_{{{variable_slot}}}\"]\n",
    "        outputs_mapping = ((0, 0),)\n",
    "        return template, exchange_msg, outputs_list, outputs_mapping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 验证权重迁移算子映射脚本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper\n",
    "from mindconverter.graph_based_converter.common.code_fragment import Fragment\n",
    "\n",
    "\n",
    "def test_mapper(onnx_info):\n",
    "    \"\"\"\n",
    "    Test mapper.\n",
    "\n",
    "    Args:\n",
    "        onnx_info (dict): Onnx operator_info. Struct is\n",
    "                                   {\n",
    "                                    'op_name': op_name,\n",
    "                                    'attributes': dict(),\n",
    "                                    'weights': [NodeWeight(), ...]\n",
    "                                   }\n",
    "    \"\"\"\n",
    "\n",
    "    template, exchange_msg, outputs_lists, outputs_mapping = \\\n",
    "        ONNXToMindSporeMapper.convert(onnx_info['op_name'],\n",
    "                                      onnx_info['attributes'],\n",
    "                                      onnx_info['weights'])\n",
    "\n",
    "    exchange_msg['var_0']['variable_name'] = 'self_defined_operator'\n",
    "    exchange_msg['var_0']['inputs'] = ['x']\n",
    "\n",
    "    trainable_params = exchange_msg['var_0']['trainable_params']\n",
    "    for weight_name, weight_inst in trainable_params.items():\n",
    "        weight = weight_inst['data']\n",
    "        weight_shape = weight.shape\n",
    "        weight_dtype = weight.dtype\n",
    "        exchange_msg['var_0']['parameters'][weight_name] = Fragment.create_parameter(weight_shape, weight_dtype)\n",
    "\n",
    "    fragment = Fragment(data_entity=exchange_msg, code_template=template, outputs=outputs_lists,\n",
    "                        outputs_mapping=outputs_mapping)\n",
    "\n",
    "    code = fragment()\n",
    "    init_code = code[0]\n",
    "    construct_code = code[1]\n",
    "    print('-'*30, 'init_code', '-'*30)\n",
    "    print('\\n'.join(init_code))\n",
    "    print('-'*30, 'construct_code', '-'*30)\n",
    "    print('\\n'.join(construct_code))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------ init_code ------------------------------\n",
      "self.self_defined_operator = P.Add()\n",
      "self.self_defined_operator_bias = Parameter(Tensor(np.random.uniform(0, 1, (1, 3, 224, 224)).astype(np.int64)), name=None)\n",
      "------------------------------ construct_code ------------------------------\n",
      "opt_self_defined_operator = self.self_defined_operator(x, self.self_defined_operator_bias)\n"
     ]
    }
   ],
   "source": [
    "onnx_operator_info = {'op_name': 'onnx::Add',\n",
    "                      'attributes': {},\n",
    "                      'weights': [NodeWeight(weight_name='onnx_bias',\n",
    "                                             weight_value=np.ones((1, 3, 224, 224), dtype=np.int),\n",
    "                                             weight_location=1)]}\n",
    "test_mapper(onnx_operator_info)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}